import torch
import torch.nn as nn


class Model(nn.Module):
    """Toy module that adds the rectified input back onto the input."""

    def forward(self, x):
        """Return ``x + relu(x)``: ``2 * x`` where ``x > 0``, ``x`` elsewhere."""
        rectified = torch.relu(x)
        return x + rectified

def main() -> torch.Tensor:
    """Compile ``Model`` with TorchDynamo, run it on a 2x3 all-ones input and print the result.

    Returns:
        The output tensor produced by the compiled model.
    """
    model = Model()
    # ``torch.compile`` is the public TorchDynamo entry point. The earlier
    # ``torch._dynamo.optimize(model)`` passed the model as the *backend*
    # argument, so the result was an optimize context (a decorator), not a
    # compiled module, and calling it with a tensor did not run the model.
    opt_model = torch.compile(model)

    x = torch.ones(2, 3, dtype=torch.float32)
    y = opt_model(x)
    print(y)
    return y


if __name__ == "__main__":
    main()